{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f30a6362-caea-4916-b796-0fbab99b41b1",
   "metadata": {},
   "source": [
    "## Retrieve and Rerank"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37f6f8d1-173d-492a-bf80-851f11071315",
   "metadata": {},
   "source": [
    "In this example we will:\n",
    "* index a BEIR dataset into Elasticsearch\n",
    "* retrieve data with BM25\n",
    "* improve relevance with a reranking model running locally on our machine\n",
    "\n",
    "Regarding the last point: even though we focus on small reranking models, running this notebook on a machine with a GPU will speed up execution."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a23ca995-c54a-4146-b7ca-e53952cb9a3a",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "For this notebook, you will need an **Elastic deployment** (we will be using [Elastic Cloud](https://www.elastic.co/guide/en/cloud/current/ec-getting-started.html); if you don't have a deployment, see below to set up a free trial), **Python 3.10.x** or later, and some **Python dependencies**:\n",
    "- `elasticsearch` (Elastic's Python client)\n",
    "- `sentence-transformers` (to load the reranking model locally)\n",
    "- `datasets` (Hugging Face's library to download datasets with minimal effort)\n",
    "- `pytrec_eval` (to compute evaluation metrics such as `nDCG@10`)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80658d09-bb17-4a50-b2c1-989d2e3dd2b7",
   "metadata": {},
   "source": [
    "## Create Elastic Cloud deployment\n",
    "\n",
    "If you don't have an Elastic Cloud deployment, sign up [here](https://cloud.elastic.co/registration?onboarding_token=vectorsearch&utm_source=github&utm_content=elasticsearch-labs-notebook) for a free trial.\n",
    "Once logged in to your Elastic Cloud account, go to the [Create deployment](https://cloud.elastic.co/deployments/create) page and select **Create deployment**. Leave all settings with their default values.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6836fec-ccd0-4fab-981c-f76f5ba7113e",
   "metadata": {},
   "source": [
    "## Installing packages\n",
    "\n",
    "Let's start by installing the necessary Python libraries (preferably in a virtual environment)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5a56591-4d9d-435b-b165-f9fbfa5615f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U elasticsearch sentence-transformers datasets pytrec_eval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfda1967-8feb-400e-b125-dc8e2c349467",
   "metadata": {},
   "source": [
    "Then we import everything we need and gradually build up our code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c5c76bc-aed0-4e44-b0a7-724470cbb7ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from getpass import getpass\n",
    "from typing import Any, Union\n",
    "\n",
    "from datasets.arrow_dataset import Dataset\n",
    "from datasets.dataset_dict import DatasetDict, IterableDatasetDict\n",
    "from datasets.iterable_dataset import IterableDataset\n",
    "from elasticsearch import Elasticsearch\n",
    "from elasticsearch.helpers import bulk\n",
    "from sentence_transformers import CrossEncoder\n",
    "from tqdm import tqdm\n",
    "import datasets\n",
    "import numpy as np\n",
    "import pytrec_eval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4128f7f-7ba2-406f-ba5d-435dd4a241f2",
   "metadata": {},
   "source": [
    "Before we dive deeper into the code, let's define the dataset name and the name of the Elasticsearch index as constants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f0a0d8-0d4c-4545-8c43-ca29a579fe62",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = \"trec-covid\"\n",
    "INDEX_NAME = f\"reranking-test-{DATASET}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "346a9c62-7e78-460c-938e-009eb6c45368",
   "metadata": {},
   "source": [
    "Let us also define the credentials required to access the Elastic Cloud deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "180ee614-224a-4a76-b33b-3ef38422e153",
   "metadata": {},
   "outputs": [],
   "source": [
    "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID: \")\n",
    "ELASTIC_API_KEY = getpass(\"Elastic API Key: \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ea552d3-5f15-421d-9119-6c06a386da69",
   "metadata": {},
   "source": [
    "and initialize the Elasticsearch Python client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a05f9722-ebc1-43fc-9fa4-c50ef72ea287",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = Elasticsearch(\n",
    "    cloud_id=ELASTIC_CLOUD_ID,\n",
    "    api_key=ELASTIC_API_KEY,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4549ab8a-6add-4a9d-a6c9-d1391de914a3",
   "metadata": {},
   "source": [
    "### Test the client\n",
    "\n",
    "Before you continue, confirm that the client has connected with this test.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0336efb4-5d77-46e4-8d93-ef03b2de1b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "client_info = client.info()\n",
    "\n",
    "f\"Successfully connected to cluster {client_info['cluster_name']} (version {client_info['version']['number']})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87eeef16-c040-4760-9be6-517fc6eefbac",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4de131fe-e8ec-40a2-92aa-5765235f01a9",
   "metadata": {},
   "source": [
    "## Helper functions\n",
    "\n",
    "In this section we define some helper functions to increase the readability of our code.\n",
    "\n",
    "Let's start with the functions that will handle the interaction with our Elastic Cloud deployment such as: \n",
    "- creating an index\n",
    "- storing the documents\n",
    "- retrieving documents with BM25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "073294e4-8893-4c0a-9e80-7f34f1ea81c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_index(es_client: Elasticsearch, name: str, analyzer: str = \"english\"):\n",
    "    \"\"\"\n",
    "    Creating an index into our deployment\n",
    "\n",
    "    Args:\n",
    "        `es_client`: An instance of a Python Elasticsearch client\n",
    "        `analyzer`: A string identifier of the language analyzer to be used. By default we use `english`\n",
    "            (more details at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-lang-analyzer.html)\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "\n",
    "    # we store `title` & `text` into separate fields and\n",
    "    _mappings = {\n",
    "        \"properties\": {\n",
    "            \"title\": {\"type\": \"text\", \"analyzer\": analyzer},\n",
    "            \"txt\": {\"type\": \"text\", \"analyzer\": analyzer},\n",
    "        }\n",
    "    }\n",
    "\n",
    "    # create an index with the specified name\n",
    "    es_client.options(ignore_status=[400]).indices.create(\n",
    "        index=name,\n",
    "        settings={\"number_of_shards\": 1},\n",
    "        mappings=_mappings,\n",
    "    )\n",
    "\n",
    "\n",
    "def index_corpus(\n",
    "    corpus: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],\n",
    "    index_name: str,\n",
    "    es_client: Elasticsearch,\n",
    "):\n",
    "    \"\"\"\n",
    "    Pushing documents over to our index\n",
    "\n",
    "    Args:\n",
    "        `corpus`: The corpus of the dataset we have selected. It's a Huggingface dataset with the three fields (`_id`, `title`, `text`)\n",
    "        `index_name`: The name of the Elasticsearch index\n",
    "        `es_client`: An instance of a Python Elasticsearch client\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "\n",
    "    def get_iterable():\n",
    "        for docid, doc_title, doc_txt in tqdm(\n",
    "            zip(corpus[\"_id\"], corpus[\"title\"], corpus[\"text\"]), total=corpus.num_rows\n",
    "        ):\n",
    "            yield {\n",
    "                \"_id\": docid,\n",
    "                \"_op_type\": \"index\",\n",
    "                \"title\": doc_title,\n",
    "                \"txt\": doc_txt,\n",
    "            }\n",
    "\n",
    "    # and bulk index them\n",
    "    bulk(client=es_client, index=index_name, actions=get_iterable(), max_retries=3)\n",
    "\n",
    "    # making sure that the index has been refreshed\n",
    "    es_client.indices.refresh(index=index_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f48c584-d9f9-42f6-8892-52705cddc7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve(\n",
    "    queries: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],\n",
    "    es_client: Elasticsearch,\n",
    "    index_name: str,\n",
    "    size: int = 10,\n",
    "    batch_size: int = 32,\n",
    "):\n",
    "    \"\"\"\n",
    "    Retrieve docs from the index by matching title, txt separately\n",
    "    Args:\n",
    "        `queries`: The queries of the dataset we have selected. It's a Huggingface dataset with the two fields (`_id`, `text`)\n",
    "        `es_client`: An instance of a Python Elasticsearch client\n",
    "        `index_name`: The name of the Elasticsearch index\n",
    "        `size`: The (maximum) number of documents that we will retrieve per query\n",
    "        `batch_size`: It represents the number of queries we can send per request.\n",
    "\n",
    "    Returns:\n",
    "        A nested dictionary where the outer key is the \"query id\" that points to (<doc_id>, <BM25-score>) key-value pairs e.g.\n",
    "        {\"my_query_id_1\": {\"my_doc_1\": 23.5, \"my_doc_2\": 11.33}, \"my_query_id_22\": {\"my_doc_3\": 20.5, \"my_doc_4\": 4.3}, ...}\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def generate_request(query_text: str):\n",
    "        \"\"\"Create the request body for the ES requests\"\"\"\n",
    "        return {\n",
    "            \"_source\": False,\n",
    "            \"query\": {\n",
    "                \"multi_match\": {\n",
    "                    \"query\": query_text,\n",
    "                    \"type\": \"best_fields\",\n",
    "                    \"fields\": [\"title\", \"txt\"],\n",
    "                    \"tie_breaker\": 0.5,\n",
    "                }\n",
    "            },\n",
    "            \"size\": size,\n",
    "        }\n",
    "\n",
    "    def retrieve_batch(query_ids, es_requests):\n",
    "        \"\"\"Get docs for a mini-batch of requests\"\"\"\n",
    "        batch_dict = dict()\n",
    "        kwargs: dict[str, Any] = {\n",
    "            \"index\": index_name,\n",
    "            \"search_type\": \"dfs_query_then_fetch\",\n",
    "        }\n",
    "        try:\n",
    "            es_response = es_client.msearch(searches=es_requests, **kwargs)\n",
    "            for qid, resp in zip(query_ids, es_response[\"responses\"]):\n",
    "                batch_dict[qid] = {\n",
    "                    hit[\"_id\"]: hit[\"_score\"] for hit in resp[\"hits\"][\"hits\"]\n",
    "                }\n",
    "        except Exception as e:\n",
    "            print(str(e))\n",
    "        return batch_dict\n",
    "\n",
    "    qids, requests = [], []\n",
    "    es_responses = dict()\n",
    "\n",
    "    for query in queries:\n",
    "        qids.append(query[\"_id\"])\n",
    "        requests.append({})\n",
    "        requests.append(generate_request(query[\"text\"]))\n",
    "\n",
    "        # retrieve in batches\n",
    "        if len(qids) == batch_size:\n",
    "            es_responses.update(retrieve_batch(qids, requests))\n",
    "            qids = []\n",
    "            requests = []\n",
    "\n",
    "    # check for leftovers\n",
    "    if len(qids) > 0:\n",
    "        es_responses.update(retrieve_batch(qids, requests))\n",
    "        qids, requests = [], []\n",
    "\n",
    "    return es_responses"
   ]
  },
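  {
   "cell_type": "markdown",
   "id": "sketch-msearch-payload",
   "metadata": {},
   "source": [
    "To make the batching in `retrieve` more concrete, here is a minimal, standalone sketch of how the flat `msearch` payload is laid out: one (empty) header followed by one request body per query. The helper name `build_msearch_payload` and the toy queries are hypothetical and only serve as an illustration; no Elasticsearch call is made.\n",
    "\n",
    "```python\n",
    "def build_msearch_payload(query_texts, size=10):\n",
    "    # Hypothetical helper: mirrors the header/body layout used in `retrieve`\n",
    "    payload = []\n",
    "    for text in query_texts:\n",
    "        # empty header: the target index is passed separately via the `index` argument\n",
    "        payload.append({})\n",
    "        # body: the same `multi_match` query that `generate_request` builds\n",
    "        payload.append(\n",
    "            {\n",
    "                '_source': False,\n",
    "                'query': {\n",
    "                    'multi_match': {\n",
    "                        'query': text,\n",
    "                        'type': 'best_fields',\n",
    "                        'fields': ['title', 'txt'],\n",
    "                        'tie_breaker': 0.5,\n",
    "                    }\n",
    "                },\n",
    "                'size': size,\n",
    "            }\n",
    "        )\n",
    "    return payload\n",
    "\n",
    "\n",
    "# toy queries, for illustration only\n",
    "payload = build_msearch_payload(['covid transmission', 'vaccine efficacy'], size=100)\n",
    "# the payload holds two entries (header + body) per query\n",
    "print(len(payload))\n",
    "```\n",
    "\n",
    "Because every query contributes two entries, the responses come back in the same order as the queries, which is why `retrieve_batch` can simply `zip` the query ids with the responses."
   ]
  },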
  {
   "cell_type": "markdown",
   "id": "2503a091-4300-412f-b0a8-96e762e763fb",
   "metadata": {},
   "source": [
    "Then, we move to functions that rely on Hugging Face's `datasets` library to fetch the `corpus`, `queries` and `qrels` files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d2a0d8e-d6f5-4f77-a2fb-1d554bcc3bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_corpus(\n",
    "    dataset_name: str,\n",
    ") -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:\n",
    "    \"\"\"\n",
    "    Download corpus from Huggingface\n",
    "    Args:\n",
    "        `dataset_name`: The name of the BEIR dataset that we have selected\n",
    "    Returns:\n",
    "        An instance of a Hugggingface dataset\n",
    "    \"\"\"\n",
    "\n",
    "    mteb_dataset_name = f\"mteb/{dataset_name}\"\n",
    "\n",
    "    # Dataset({\n",
    "    #     features: ['_id', 'title', 'text'],\n",
    "    #     num_rows: 25657\n",
    "    # })\n",
    "    corpus = datasets.load_dataset(mteb_dataset_name, \"corpus\", split=\"corpus\")\n",
    "\n",
    "    return corpus\n",
    "\n",
    "\n",
    "def download_queries_and_qrels(dataset_name: str):\n",
    "    \"\"\"\n",
    "    Download queries, qrels from Huggingface\n",
    "    Args:\n",
    "        `dataset_name`: The name of the BEIR dataset that we have selected\n",
    "    Returns:\n",
    "        A tuple of: (<an instance of a Hugggingface dataset>, <a dictionary holding the qrels information>)\n",
    "    \"\"\"\n",
    "\n",
    "    mteb_dataset_name = f\"mteb/{dataset_name}\"\n",
    "    qrels_raw = datasets.load_dataset(\n",
    "        mteb_dataset_name,\n",
    "        \"default\",\n",
    "        split=\"test\" if dataset_name != \"msmarco\" else \"dev\",\n",
    "    )\n",
    "\n",
    "    # convert to `pytrec_eval` compatible format\n",
    "    qrels = defaultdict(dict)\n",
    "    for q in qrels_raw:\n",
    "        qrels[q[\"query-id\"]][q[\"corpus-id\"]] = int(q[\"score\"])\n",
    "\n",
    "    queries = datasets.load_dataset(\n",
    "        mteb_dataset_name, \"queries\", split=\"queries\"\n",
    "    ).filter(lambda r: r[\"_id\"] in qrels)\n",
    "\n",
    "    return queries, dict(qrels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd0e0892-fcd3-44db-b7a3-d290782d19a5",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "093fa778-5563-41bb-872e-f5bbc5625a29",
   "metadata": {},
   "source": [
    "## Running the pipeline\n",
    "\n",
    "Now, we can execute the \"retrieve and rerank\" pipeline step by step.\n",
    "\n",
    "### Indexing the corpus into Elasticsearch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0e235ee-39f0-441d-b062-5231f70ae5d7",
   "metadata": {},
   "source": [
    "First, we create the index that will host the corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26586596-0be9-46ca-a881-b5c83b57f3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_index(name=INDEX_NAME, es_client=client)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339a8ea3-31ad-4fb6-8dba-a42588313fc3",
   "metadata": {},
   "source": [
    "Then, we download the corpus and push it into the index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0c2789f-b1f0-41b2-a06f-fd797e5d214e",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = download_corpus(dataset_name=DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40c84175-87f0-4507-9232-07783beef65a",
   "metadata": {},
   "outputs": [],
   "source": [
    "index_corpus(es_client=client, corpus=corpus, index_name=INDEX_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9854e14f-5be1-437c-85cf-a65c1aa61a54",
   "metadata": {},
   "source": [
    "Let's move to the retrieval part"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e91b00db-74cb-4af2-947e-8f5885e3f584",
   "metadata": {},
   "source": [
    "### 1st stage retrieval with BM25\n",
    "\n",
    "First, we download the queries and the relevance judgments (`qrels`) of the `test` split of the dataset we have selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00900be1-4c48-4dde-81b0-24000d71925a",
   "metadata": {},
   "outputs": [],
   "source": [
    "queries, qrels = download_queries_and_qrels(dataset_name=DATASET)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ca5ec96-f04e-4a99-bdb8-2f9134295795",
   "metadata": {},
   "source": [
    "* The `queries` file is a Hugging Face dataset with two keys (`_id`, `text`)\n",
    "* The `qrels` file contains the relationships between a `query_id` and a list of documents. We have transformed it into a `pytrec_eval`-compatible format, i.e. a nested dictionary where the outer key is the query id that points to a dictionary of (`doc_id`, `score`) key-value pairs (a score > 0 denotes relevance)"
   ]
  },
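  {
   "cell_type": "markdown",
   "id": "sketch-qrels-format",
   "metadata": {},
   "source": [
    "As a small illustration of the nested `qrels` format described above, the sketch below applies the same conversion that `download_queries_and_qrels` performs to a handful of made-up rows. The ids and scores in `toy_qrels_rows` are invented for this example and do not come from the dataset.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# made-up rows that mimic the (`query-id`, `corpus-id`, `score`) layout of the qrels split\n",
    "toy_qrels_rows = [\n",
    "    {'query-id': 'q1', 'corpus-id': 'd1', 'score': 2},\n",
    "    {'query-id': 'q1', 'corpus-id': 'd7', 'score': 0},\n",
    "    {'query-id': 'q2', 'corpus-id': 'd3', 'score': 1},\n",
    "]\n",
    "\n",
    "# group the judgments by query id: {query_id: {doc_id: score}}\n",
    "toy_qrels = defaultdict(dict)\n",
    "for row in toy_qrels_rows:\n",
    "    toy_qrels[row['query-id']][row['corpus-id']] = int(row['score'])\n",
    "toy_qrels = dict(toy_qrels)\n",
    "\n",
    "print(toy_qrels)\n",
    "```\n",
    "\n",
    "The retrieval results returned by `retrieve` follow the same nested layout, with BM25 scores in place of relevance labels, so both can be passed directly to `pytrec_eval`."
   ]
  },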
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "830ea137-10af-4a7b-8b03-a46db89399e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7ca8e05-853c-47fb-b0f2-9ef3f6325e0d",
   "metadata": {},
   "source": [
    "Now, let's retrieve the **top-100** documents per query using BM25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d950d8a3-a614-4a07-9c61-719ae5a85de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "bm25_responses = retrieve(\n",
    "    queries=queries, index_name=INDEX_NAME, size=100, es_client=client\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2acaca72-2338-42ea-9a19-f37646245166",
   "metadata": {},
   "source": [
    "And finally, let's compute the performance of BM25 on this dataset. We are using `nDCG@10` as our metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb183c94-7974-4bd8-9ed2-f1426d567592",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify evaluator\n",
    "METRICS_TO_EVALUATE = {\"ndcg_cut_10\"}\n",
    "evaluator = pytrec_eval.RelevanceEvaluator(qrels, METRICS_TO_EVALUATE)\n",
    "\n",
    "\n",
    "# get score per query\n",
    "eval_per_query = evaluator.evaluate(bm25_responses)\n",
    "\n",
    "\n",
    "# aggregate scores across queries\n",
    "eval_scores = defaultdict(list)\n",
    "\n",
    "for _, vals in eval_per_query.items():\n",
    "    for metric, metric_score in vals.items():\n",
    "        eval_scores[metric].append(metric_score)\n",
    "\n",
    "for metric, _scores in eval_scores.items():\n",
    "    print(f\"{metric}: {np.mean(_scores)}\")"
   ]
  },
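  {
   "cell_type": "markdown",
   "id": "sketch-ndcg-at-k",
   "metadata": {},
   "source": [
    "To give some intuition for the metric reported above, here is a minimal sketch of `nDCG@k` for a single query. It assumes a linear gain and a log2 rank discount; other formulations exist (e.g. with an exponential gain), so treat it as an illustration rather than as the exact computation that `pytrec_eval` performs. The helper `ndcg_at_k` and the toy data are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def ndcg_at_k(doc_scores, doc_labels, k=10):\n",
    "    # Hypothetical helper, not part of `pytrec_eval`\n",
    "    # rank documents by descending retrieval score and keep the top k\n",
    "    ranked = sorted(doc_scores, key=doc_scores.get, reverse=True)[:k]\n",
    "    # discounted cumulative gain; unjudged documents count as non-relevant (gain 0)\n",
    "    dcg = sum(\n",
    "        doc_labels.get(doc_id, 0) / math.log2(rank + 2)\n",
    "        for rank, doc_id in enumerate(ranked)\n",
    "    )\n",
    "    # ideal DCG: the best ordering achievable with the judged documents\n",
    "    ideal = sorted(doc_labels.values(), reverse=True)[:k]\n",
    "    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))\n",
    "    return dcg / idcg if idcg > 0 else 0.0\n",
    "\n",
    "\n",
    "# made-up scores and labels: the most relevant document (d3) is ranked last\n",
    "toy_run = {'d1': 12.0, 'd2': 9.5, 'd3': 3.1}\n",
    "toy_labels = {'d1': 0, 'd2': 1, 'd3': 2}\n",
    "toy_ndcg = ndcg_at_k(toy_run, toy_labels)\n",
    "print(round(toy_ndcg, 4))\n",
    "```\n",
    "\n",
    "A ranking that places the documents in order of decreasing relevance label would score 1.0, which is what a good reranker tries to approach."
   ]
  },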
  {
   "cell_type": "markdown",
   "id": "67b32dca-62c3-4369-a4c4-bcb704717498",
   "metadata": {},
   "source": [
    "## 2nd stage reranking\n",
    "\n",
    "Now, let's move to the reranking part. In this example we are using a small cross-encoder model to optimize the ordering of our results. We will use the `sentence-transformers` library to load the model and do the scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1955a3c-300e-49ca-80a8-0566f6efe713",
   "metadata": {},
   "outputs": [],
   "source": [
    "reranking_model = CrossEncoder(\"cross-encoder/ms-marco-MiniLM-L-6-v2\", max_length=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b87d87cd-4c87-4ec7-b3d5-e69aa3433f65",
   "metadata": {},
   "source": [
    "Next, we build two lookup dictionaries (query id to query text, and document id to the concatenated title and text) to speed up processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22040c13-4040-4ddc-b1e1-ec05662fd64b",
   "metadata": {},
   "outputs": [],
   "source": [
    "queries_dict = {q[\"_id\"]: q[\"text\"] for q in queries}\n",
    "corpus_dict = {doc[\"_id\"]: f\"{doc['title']} {doc['text']}\" for doc in corpus}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f11f5dbc-aa54-4a78-a61a-fa67b2742bf1",
   "metadata": {},
   "source": [
    "and now it's time for the reranking part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15a48d1d-e295-43ff-a9ed-273d3f21a9e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_after_reranking = dict()\n",
    "\n",
    "for qid, bm25_res in tqdm(bm25_responses.items(), total=len(bm25_responses)):\n",
    "\n",
    "    query_text = queries_dict[qid]\n",
    "    doc_ids = [doc_id for doc_id, _ in bm25_res.items()]\n",
    "    if len(doc_ids) == 0:\n",
    "        results_after_reranking[qid] = dict()\n",
    "        continue\n",
    "\n",
    "    doc_texts = [corpus_dict[doc_id] for doc_id in doc_ids]\n",
    "\n",
    "    # rescore with the reranking model\n",
    "    scores = reranking_model.predict([(query_text, doc_text) for doc_text in doc_texts])\n",
    "\n",
    "    results_after_reranking[qid] = {\n",
    "        doc_id: float(score) for doc_id, score in zip(doc_ids, scores)\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a85e0f51-fedc-4490-800c-48fd268d8db7",
   "metadata": {},
   "source": [
    "and let's calculate the metric scores for the reranked results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c145552-a372-4624-92f8-4798636dd3ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "post_reranking_eval_scores_per_query = evaluator.evaluate(results_after_reranking)\n",
    "\n",
    "post_reranking_eval_scores = defaultdict(list)\n",
    "\n",
    "for qid, vals in post_reranking_eval_scores_per_query.items():\n",
    "    for metric, metric_score in vals.items():\n",
    "        post_reranking_eval_scores[metric].append(metric_score)\n",
    "\n",
    "for metric, scores in post_reranking_eval_scores.items():\n",
    "    print(f\"{metric}: {np.mean(scores)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d35c3a11-9d85-4b5b-9c68-0add866f3700",
   "metadata": {},
   "source": [
    "In most cases, reranking will provide a noticeable boost in `nDCG@10` over the BM25 baseline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5f50faf-9192-43b6-be35-6700b740881d",
   "metadata": {},
   "source": [
    "## Bonus section\n",
    "\n",
    "### Judge rate\n",
    "Let's do some extra analysis and try to answer the question `\"How often is an evaluator presented with (query, document) pairs for which there is no ground truth information?\"`\n",
    "To do so we compute the *judge rate*, i.e. the percentage of documents in the result list for which the `qrels` file contains a relevance score; the remaining documents are unjudged and are counted as non-relevant during evaluation.\n",
    "Let's start with BM25 by focusing on the **top-10** retrieved documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fac2ad-3780-4931-9ed2-127a804fb9f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOP_K = 10\n",
    "\n",
    "judge_rate_per_query = []\n",
    "\n",
    "for qid, doc_scores in bm25_responses.items():\n",
    "    top_k_doc_ids = [\n",
    "        doc_id\n",
    "        for doc_id, score in sorted(\n",
    "            doc_scores.items(), key=lambda x: x[1], reverse=True\n",
    "        )[:TOP_K]\n",
    "    ]\n",
    "    if len(top_k_doc_ids) == 0:\n",
    "        continue\n",
    "\n",
    "    nr_labeled_docs = sum(1 for doc_id in top_k_doc_ids if doc_id in qrels[qid])\n",
    "    judge_rate_per_query.append(nr_labeled_docs / len(top_k_doc_ids))\n",
    "\n",
    "print(f'\"Judge rate\" for {DATASET} is {np.mean(judge_rate_per_query) * 100.0:.3}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b35f1415-46dc-4f8e-b4ce-1c575951b7a9",
   "metadata": {},
   "source": [
    "while for the reranked documents it is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cefae0b2-c963-4d38-b352-94428be35bf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "judge_rate_per_query = []\n",
    "\n",
    "for qid, doc_scores in results_after_reranking.items():\n",
    "    top_k_doc_ids = [\n",
    "        doc_id\n",
    "        for doc_id, score in sorted(\n",
    "            doc_scores.items(), key=lambda x: x[1], reverse=True\n",
    "        )[:TOP_K]\n",
    "    ]\n",
    "    if len(top_k_doc_ids) == 0:\n",
    "        continue\n",
    "\n",
    "    nr_labeled_docs = sum(1 for doc_id in top_k_doc_ids if doc_id in qrels[qid])\n",
    "    judge_rate_per_query.append(nr_labeled_docs / len(top_k_doc_ids))\n",
    "\n",
    "print(\n",
    "    f'\"Judge rate\" for {DATASET} (reranked) is {np.mean(judge_rate_per_query) * 100.0:.3}%'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eae7142374e2f7a",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "392ab084d5bc024",
   "metadata": {},
   "source": [
    "### Confidence intervals"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb391b00a7efd58",
   "metadata": {},
   "source": [
    "In this section we will briefly touch upon the concepts of `confidence intervals` and `statistical significance` and we will see how we can use them to determine whether improvements in our pipelines are significant or not.\n",
    "\n",
    "We can think of it as follows: Our goal is to estimate the performance of our pipeline (retrieval and/or reranking) on a target corpus. Ideally, we would like to have access to **all** queries that our end-users will run against it but of course this is impossible. Instead, we have the set of test queries provided by the benchmark and we implicitly assume that the performance on this set can act as an accurate proxy of the overall performance (in the ideal scenario).\n",
    "\n",
    "But we can make some extra assumptions to increase the reliability of our analysis. [Confidence intervals](https://en.wikipedia.org/wiki/Confidence_interval), a concept from statistical theory, give us a tool to handle our uncertainty. By setting a certain level of confidence, let's go with 95% in this example, we can derive a range of values that will likely contain the parameter of interest (here the performance in the **ideal** scenario). In other words, if we repeated the same process an infinite number of times (by drawing different test sets) we could be confident that in 95% of them the confidence interval would encompass the true value.\n",
    "\n",
    "The code below shows an example of deriving confidence intervals using [bootstrapping](https://en.wikipedia.org/wiki/Bootstrapping_\\(statistics\\)) combined with the `percentile` method. It should be noted that this statistic is affected a lot by the number and the variability of queries in the dataset i.e. smaller confidence intervals are expected for larger query sets and vice versa\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b5563e681b4236",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ci_with_bootstrapping(scores: list, nr_bootstraps=1000, percentile=95):\n",
    "    \"\"\"\n",
    "    Compute confidence intervals using bootstrapping and the percentile method\n",
    "    Args:\n",
    "        `scores`: The list of scores to be averaged\n",
    "        `nr_bootstraps`: The number of bootstrap samples to collect\n",
    "        `percentile`: The type of confidence interval to compute. It should be a number in (0, 100),\n",
    "            by default it computes 95% CI\n",
    "    Returns:\n",
    "        The confidence interval\n",
    "    \"\"\"\n",
    "    estimates = []\n",
    "    for _ in range(nr_bootstraps):\n",
    "        sample = np.random.choice(scores, len(scores), replace=True)\n",
    "        estimates.append(np.mean(sample))\n",
    "\n",
    "    half_percentile = (100.0 - percentile) / 2.0\n",
    "    return np.percentile(estimates, [half_percentile, 100.0 - half_percentile])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "701ed82491660fd6",
   "metadata": {},
   "source": [
    "and we can apply it to our results as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f2f080185190b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "ndcg_scores = post_reranking_eval_scores[\"ndcg_cut_10\"]\n",
    "get_ci_with_bootstrapping(ndcg_scores, percentile=95, nr_bootstraps=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be4d62dfde0c588b",
   "metadata": {},
   "source": [
    "The way to interpret this is to say that we are 95% confident that the `nDCG@10` score in the ideal scenario lies within that interval.\n",
    "\n",
    "Confidence intervals can also be used in the context of significance testing. For example, if we wanted to compare two pipelines (retrieval and/or reranking) on a dataset, one way to do this would be to:\n",
    "* Decide on a confidence level (e.g. 90% or 95%)\n",
    "* Compute the confidence interval for the performance of pipeline A\n",
    "* Compute the confidence interval for the performance of pipeline B\n",
    "* Check whether the two intervals overlap\n",
    "\n",
    "In the last step, if there is **no** overlap we can say that the observed difference in performance between the two pipelines is **statistically significant**. The converse does not hold: overlapping intervals do not, on their own, show that the difference is not significant."
   ]
  },
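  {
   "cell_type": "markdown",
   "id": "sketch-ci-overlap",
   "metadata": {},
   "source": [
    "The comparison steps listed above can be sketched as follows. This is a minimal, standalone illustration that uses only the standard library: `bootstrap_ci` is a hypothetical variant of `get_ci_with_bootstrapping` with a simple nearest-rank percentile, `intervals_overlap` is a hypothetical helper, and the per-query scores are made up.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def bootstrap_ci(scores, nr_bootstraps=1000, confidence=95, seed=0):\n",
    "    # Hypothetical standard-library variant of the percentile bootstrap\n",
    "    rng = random.Random(seed)\n",
    "    # mean of each bootstrap sample (drawn with replacement), sorted ascending\n",
    "    means = sorted(\n",
    "        sum(rng.choices(scores, k=len(scores))) / len(scores)\n",
    "        for _ in range(nr_bootstraps)\n",
    "    )\n",
    "    # nearest-rank approximation of the lower and upper percentiles\n",
    "    tail = (100.0 - confidence) / 200.0\n",
    "    lower = means[int(tail * nr_bootstraps)]\n",
    "    upper = means[min(int((1.0 - tail) * nr_bootstraps), nr_bootstraps - 1)]\n",
    "    return lower, upper\n",
    "\n",
    "\n",
    "def intervals_overlap(ci_a, ci_b):\n",
    "    # two closed intervals overlap when each one starts before the other ends\n",
    "    return ci_a[0] <= ci_b[1] and ci_b[0] <= ci_a[1]\n",
    "\n",
    "\n",
    "# made-up per-query scores for two pipelines, for illustration only\n",
    "scores_a = [0.30, 0.35, 0.28, 0.33, 0.31, 0.29, 0.34, 0.32]\n",
    "scores_b = [0.70, 0.66, 0.72, 0.69, 0.71, 0.68, 0.73, 0.67]\n",
    "\n",
    "ci_a = bootstrap_ci(scores_a)\n",
    "ci_b = bootstrap_ci(scores_b)\n",
    "print(ci_a, ci_b, intervals_overlap(ci_a, ci_b))\n",
    "```\n",
    "\n",
    "In the notebook itself, the two score lists would be the per-query `nDCG@10` values collected for BM25 and for the reranked results."
   ]
  },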
  {
   "cell_type": "markdown",
   "id": "4f8bf706821ab252",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
